from models.AlexNet_model import AlexNet
from models.LeNet_model import LeNet
from models.VGG_models import VGG
from models.resnet import resnet


def select_model(model_name):
    """Build and return a fresh model instance for ``model_name``.

    The name is compared with exact, case-sensitive string equality.
    Each recognised key calls the matching constructor with no arguments:

        'lenet'   -> LeNet()
        'AlexNet' -> AlexNet()
        'VGG_16'  -> VGG()
        'resnet'  -> resnet()

    Args:
        model_name: Key selecting which model to construct.

    Returns:
        A new instance of the selected model, or ``None`` when
        ``model_name`` matches none of the keys above. For example,
        ``'LeNet'`` or ``'vgg_16'`` also return ``None``.
    """
    # NOTE(review): key casing is inconsistent ('lenet' vs 'AlexNet' vs
    # 'VGG_16'). Callers must pass these exact spellings.
    if model_name == 'lenet':
        return LeNet()
    if model_name == 'AlexNet':
        return AlexNet()
    if model_name == 'VGG_16':
        return VGG()
    if model_name == 'resnet':
        return resnet()
    # Unrecognised names fall through. Callers are responsible for
    # checking the result against None.
    return None
